import os
import csv
import argparse
import random
import numpy as np
import matplotlib.pyplot as plt


def parse_args():
    p = argparse.ArgumentParser(description="Vol2 Tick-Chain Double-Flip Monte Carlo")
    p.add_argument("--seed", type=int, default=20250818, help="RNG seed for reproducibility (A3)")
    p.add_argument("--trials", type=int, default=10_000, help="number of Monte Carlo trials")
    p.add_argument("--N", type=int, default=5, help="max radius (state space is -N..+N)")
    return p.parse_args()


def set_seed(seed: int | None):
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)


# Column order shared by the CSV header and the histogram legend.
_COLUMNS = ["r0", "Out1", "Out2", "Out2B", "In1", "In2", "In2B"]


# --- Operators (mass-preserving with boundary accumulation) ---
def _flip_down(M: np.ndarray) -> np.ndarray:
    """F: shift mass toward more negative r (r -> r-1).

    Mass already at the lower boundary (index 0, r = -N) stays there instead
    of falling off the lattice, so total mass is preserved.
    """
    new = np.zeros_like(M)
    new[:-1] = M[1:]     # shift left
    new[0] += M[0]       # boundary at -N keeps its own mass
    return new


def _flip_up(M: np.ndarray) -> np.ndarray:
    """S: shift mass toward more positive r (r -> r+1).

    Mass already at the upper boundary (last index, r = +N) stays there,
    so total mass is preserved.
    """
    new = np.zeros_like(M)
    new[1:] = M[:-1]     # shift right
    new[-1] += M[-1]     # boundary at +N keeps its own mass
    return new


def _collapse(M: np.ndarray, r_vals: np.ndarray) -> np.ndarray:
    """B: move the total mass of ``M`` onto the r = 0 site."""
    collapsed = np.zeros_like(M)
    collapsed[r_vals == 0] = M.sum()
    return collapsed
# ---------------------------------------------------------------


def _trace_from(r0: int, r_vals: np.ndarray) -> tuple[int, ...]:
    """Return the argmax positions (Out1, Out2, Out2B, In1, In2, In2B) for a unit spike at ``r0``.

    The "outward" chain applies F twice then B; the "inward" chain applies S
    twice then B (naming follows the original script's comments).
    """
    M0 = np.zeros_like(r_vals, dtype=float)
    M0[r_vals == r0] = 1.0

    out1 = _flip_down(M0)
    out2 = _flip_down(out1)
    in1 = _flip_up(M0)
    in2 = _flip_up(in1)

    states = (out1, out2, _collapse(out2, r_vals), in1, in2, _collapse(in2, r_vals))
    return tuple(int(r_vals[M.argmax()]) for M in states)


def _write_csv(path: str, rows: list[list[int]]) -> None:
    """Write the header plus one row per trial to ``path``."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(_COLUMNS)
        writer.writerows(rows)


def _plot_histograms(rows: list[list[int]], N: int, path: str) -> None:
    """Overlay one density histogram per column and save it to ``path``."""
    data = np.asarray(rows)
    # One unit-width bin centred on each lattice point -N..N, shared by all
    # series so the overlaid bars line up. Deriving bins from each series'
    # own data range would misalign them.
    edges = np.arange(-N - 0.5, N + 1.5)
    fig, ax = plt.subplots(figsize=(8, 6))
    try:
        for idx, label in enumerate(_COLUMNS):
            ax.hist(data[:, idx], bins=edges, alpha=0.5, label=label, density=True)
        ax.legend()
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Release the figure; pyplot otherwise keeps it alive across calls.
        plt.close(fig)


def run(seed: int, n_trials: int, N: int, out_dir: str = "results"):
    """Run the tick-chain double-flip Monte Carlo and write its outputs.

    Each trial draws r0 uniformly from -N..N with Python's ``random``
    module. It then records where a unit spike at r0 peaks after each of the
    chains F, F∘F, B∘F∘F and S, S∘S, B∘S∘S.

    Args:
        seed: Passed to ``set_seed`` before sampling, for reproducibility.
        n_trials: Number of trials (rows in the CSV); must be >= 1.
        N: Maximum radius; the state space is -N..+N. Must be >= 0.
        out_dir: Directory for ``tick_chain_results.csv`` and
            ``tick_chain_plot.png``. It is created if missing. The default
            matches the original hard-coded location.

    Raises:
        ValueError: if ``N`` < 0 or ``n_trials`` < 1.
    """
    if N < 0:
        raise ValueError(f"N must be >= 0, got {N}")
    if n_trials < 1:
        raise ValueError(f"n_trials must be >= 1, got {n_trials}")

    set_seed(seed)

    # Discrete state space
    r_vals = np.arange(-N, N + 1)

    # The operator chains are deterministic given r0, so compute each
    # state's outcome once instead of once per trial.
    outcome_by_r0 = {int(r): _trace_from(r, r_vals) for r in r_vals}

    # Same sequence of random.choice calls as before, so the CSV stays
    # byte-identical for a given seed (A3 hashing).
    rows = []
    for _ in range(n_trials):
        r0 = int(random.choice(r_vals))
        rows.append([r0, *outcome_by_r0[r0]])

    os.makedirs(out_dir, exist_ok=True)
    csv_path = f"{out_dir}/tick_chain_results.csv"
    _write_csv(csv_path, rows)

    # The PNG is for inspection only. Use the CSV, not the PNG, for A3 hashing.
    _plot_histograms(rows, N, f"{out_dir}/tick_chain_plot.png")

    print(f"[A3] ran with seed={seed}, trials={n_trials}, N={N}")
    print(f"CSV written to {csv_path}")


if __name__ == "__main__":
    # Script entry point: read CLI options, then run the simulation with them.
    cli_args = parse_args()
    run(N=cli_args.N, seed=cli_args.seed, n_trials=cli_args.trials)
